package spark.movie;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import scala.collection.JavaConversions;
import scala.collection.Seq;
import spark.domain.MovieInfo;
import spark.domain.RatingsInfo;
import spark.domain.UserInfo;
import spark.util.SparkUtil;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Description:
 * Implements the movie-review analysis ("movie common") using the Dataset/DataFrame API.
 *
 * @author jiantao7
 * @create 2018-05-17 15:51
 */
public class FilmReviewSystemThroughDataFrameOrDataSet {

    /** Default base directory containing users.dat / movies.dat / ratings.dat (MovieLens 1M layout). */
    private static final String DEFAULT_DATA_DIR = "E:\\data\\ml-1m";

    public static void main(String[] args) {
        // Allow overriding the data directory on the command line; the default
        // preserves the original hard-coded behavior.
        String dataDir = args.length > 0 ? args[0] : DEFAULT_DATA_DIR;

        SparkUtil sparkUtil = new SparkUtil("movieCommonAnalisys", true);
        SparkSession spark = sparkUtil.getSparkSession();

        Encoder<UserInfo> userEncoder = Encoders.bean(UserInfo.class);
        Encoder<MovieInfo> movieEncoder = Encoders.bean(MovieInfo.class);
        Encoder<RatingsInfo> ratingEncoder = Encoders.bean(RatingsInfo.class);

        /*
         * users.dat layout: UserID::Gender::Age::Occupation::Zip-code
         * Schema is attached via an Encoder here; the alternative —
         * spark.createDataFrame(rowRDD, structType) — is demonstrated in
         * useStructTypeAddSchemaToDf below.
         */
        Dataset<UserInfo> userDataSet = spark.read().text(dataDir + "\\users.dat").map(new MapFunction<Row, UserInfo>() {
            @Override
            public UserInfo call(Row value) throws Exception {
                // Read the single "value" column directly rather than parsing
                // Row.toString() and stripping its surrounding "[" and "]".
                String[] strings = value.getString(0).split("::");
                UserInfo userInfo = new UserInfo();
                userInfo.setUserId(strings[0].trim());
                userInfo.setGender(strings[1].trim());
                userInfo.setAge(Integer.parseInt(strings[2].trim()));
                userInfo.setOccupation(strings[3].trim());
                userInfo.setZipCode(strings[4].trim());
                return userInfo;
            }
        }, userEncoder);

        /*
         * movies.dat layout: MovieID::Title::Genres
         */
        Dataset<MovieInfo> movieDataSet = spark.read().text(dataDir + "\\movies.dat").map(new MapFunction<Row, MovieInfo>() {
            @Override
            public MovieInfo call(Row value) throws Exception {
                String[] strings = value.getString(0).split("::");
                MovieInfo movieInfo = new MovieInfo();
                movieInfo.setMovieId(strings[0]);
                movieInfo.setMovieName(strings[1]);
                movieInfo.setGenres(strings[2]);
                return movieInfo;
            }
        }, movieEncoder);

        /*
         * ratings.dat layout: UserID::MovieID::Rating::Timestamp
         */
        Dataset<RatingsInfo> ratingsDataSet = spark.read().text(dataDir + "\\ratings.dat").map(new MapFunction<Row, RatingsInfo>() {
            @Override
            public RatingsInfo call(Row value) throws Exception {
                String[] strings = value.getString(0).split("::");
                RatingsInfo ratingsInfo = new RatingsInfo();
                ratingsInfo.setUserId(strings[0]);
                ratingsInfo.setMovieId(strings[1]);
                ratingsInfo.setRating(Integer.parseInt(strings[2]));
                ratingsInfo.setTimestamp(Long.parseLong(strings[3]));
                return ratingsInfo;
            }
        }, ratingEncoder);

        String movieId = "1193";
        // Alternative implementations of the same analysis, kept for reference:
//        getFemaleAndMaleSeeMoviesCountBySparkSql(spark, userDataSet, movieDataSet, ratingsDataSet, movieId);
//        getFemaleAndMaleSeeMoviesCountByDfOp(spark, userDataSet, movieDataSet, ratingsDataSet, movieId);
//        getFemaleAndMaleSeeMovieCountByOnlySql(spark, userDataSet, ratingsDataSet, movieId);

        // Raw (un-typed) DataFrames for the StructType-based schema demo.
        Dataset<Row> userDF = spark.read().text(dataDir + "\\users.dat");
        Dataset<Row> movieDF = spark.read().text(dataDir + "\\movies.dat");
        Dataset<Row> ratingDF = spark.read().text(dataDir + "\\ratings.dat");

        useStructTypeAddSchemaToDf(spark, userDF, movieDF, ratingDF);
    }

    /**
     * Attaches a schema to the raw user lines using an explicit {@link StructType}
     * plus {@code spark.createDataFrame(rowRDD, schema)} — the alternative to
     * attaching schema via an {@link Encoder}.
     *
     * @param spark    active SparkSession
     * @param userDF   raw users.dat lines (single string column named "value")
     * @param movieDF  raw movies.dat lines — currently unused, kept for symmetry
     * @param ratingDF raw ratings.dat lines — currently unused, kept for symmetry
     */
    private static void useStructTypeAddSchemaToDf(SparkSession spark, Dataset<Row> userDF, Dataset<Row> movieDF, Dataset<Row> ratingDF) {
        StructType useSchema = DataTypes.createStructType(Arrays.asList(
                DataTypes.createStructField("userId", DataTypes.StringType, false),
                DataTypes.createStructField("gender", DataTypes.StringType, false),
                DataTypes.createStructField("age", DataTypes.IntegerType, false),
                DataTypes.createStructField("occupation", DataTypes.StringType, false),
                DataTypes.createStructField("zipCode", DataTypes.StringType, false)
        ));

        JavaRDD<Row> rowJavaRDD = userDF.javaRDD().map(new Function<Row, Row>() {
            @Override
            public Row call(Row v1) throws Exception {
                // BUG FIX: the previous code split v1.toString(), which left
                // Row.toString()'s "[" attached to the first field and "]" to the
                // last ("[1", ..., "48067]"). getString(0) yields the bare line.
                String[] strings = v1.getString(0).split("::");
                return RowFactory.create(strings[0], strings[1], Integer.parseInt(strings[2]), strings[3], strings[4]);
            }
        });
        Dataset<Row> userDataFrame = spark.createDataFrame(rowJavaRDD, useSchema);
        userDataFrame.printSchema();
        userDataFrame.show(50);
    }

    /**
     * Pure Spark SQL implementation: for one movie, counts its viewers broken
     * down by gender and age; additionally prints the top-10 movies by average
     * rating.
     *
     * @param spark          active SparkSession
     * @param userDataSet    typed user records
     * @param ratingsDataSet typed rating records
     * @param movieId        id of the movie to analyse (numeric string)
     */
    private static void getFemaleAndMaleSeeMovieCountByOnlySql(SparkSession spark, Dataset<UserInfo> userDataSet, Dataset<RatingsInfo> ratingsDataSet, String movieId) {
        try {
            userDataSet.createTempView("userTable");
            ratingsDataSet.createTempView("ratingsTable");

            String sqlText = "select gender,age,count(*) from userTable u join ratingsTable r on u.userId=r.userId where movieId=" + movieId + " group by gender,age";
            spark.sql(sqlText).show(50);
        } catch (AnalysisException e) {
            e.printStackTrace();
        }
        // BUG FIX: orderBy("avg(rating)", "desc") used the varargs-String overload,
        // which treats "desc" as a SECOND SORT COLUMN and fails at runtime with an
        // AnalysisException. Use an explicit descending Column instead.
        ratingsDataSet.select("movieId", "rating")
                .groupBy("movieId")
                .avg("rating")
                .orderBy(functions.desc("avg(rating)"))
                .show(10);
    }

    /**
     * DataFrame-operator implementation: for one movie, counts its viewers
     * broken down by gender and age.
     *
     * @param spark          active SparkSession — currently unused, kept for a
     *                       signature parallel with the other implementations
     * @param userDataSet    typed user records
     * @param movieDataSet   typed movie records — currently unused
     * @param ratingsDataSet typed rating records
     * @param movieId        id of the movie to analyse (numeric string)
     */
    private static void getFemaleAndMaleSeeMoviesCountByDfOp(SparkSession spark,
                                                             Dataset<UserInfo> userDataSet,
                                                             Dataset<MovieInfo> movieDataSet,
                                                             Dataset<RatingsInfo> ratingsDataSet, String movieId) {
        // Keep only ratings of the requested movie, then attach user attributes.
        Dataset<RatingsInfo> filterRatingsInfoDataset = ratingsDataSet.filter("movieId=" + movieId);

        // Join on "userId" via the USING-column overload so the join key appears once.
        Dataset<Row> userId = filterRatingsInfoDataset.join(userDataSet, "userId");

        userId.groupBy(userId.col("gender"), userId.col("age")).count().show(50);
    }

    /**
     * Mixed DataFrame-operator + Spark SQL implementation: for one movie, counts
     * its viewers broken down by movie name, gender and age.
     *
     * @param spark     active SparkSession
     * @param userDS    typed user records
     * @param movieDS   typed movie records (supplies the movie name)
     * @param ratingsDS typed rating records
     * @param movieId   id of the movie to analyse (numeric string)
     */
    private static void getFemaleAndMaleSeeMoviesCountBySparkSql(SparkSession spark, Dataset<UserInfo> userDS,
                                                                 Dataset<MovieInfo> movieDS, Dataset<RatingsInfo> ratingsDS,
                                                                 String movieId) {
        // Join userDS and ratingsDS on userId. The Scala-Seq overload is used so
        // the join key appears only once in the output (avoids ambiguous columns).
        List<String> userIdList = new ArrayList<String>();
        userIdList.add("userId");

        Seq<String> joinCol = JavaConversions.asScalaBuffer(userIdList).toSeq();
        Dataset<Row> userIdJoinDF = userDS.join(ratingsDS, joinCol);

        try {
            // Register a temporary view for the SQL step.
            userIdJoinDF.createTempView("userGenderAndRating");
            // Restrict to the requested movie.
            Dataset<Row> femaleMovieDF = spark.sql(
                    "select * from userGenderAndRating where movieId=" + movieId);
            // Joining via a USING column drops the duplicate movieId column,
            // avoiding ambiguous-reference errors in the query below.
//            Dataset<Row> movieNameDF = femaleMovieDF.as("a").join(movieDS.as("b")).where("a.movieId=b.movieId").drop(movieDS.col("movieId"));
            List<String> joinName = new ArrayList<String>();
            joinName.add("movieId");
            Seq<String> uniqJoinName = JavaConversions.asScalaBuffer(joinName).toSeq();
            Dataset<Row> movieNameDF = femaleMovieDF.join(movieDS, uniqJoinName);

            movieNameDF.createTempView("movieNameTable");

            Dataset<Row> rowDataset = spark.sql(
                    "select movieName,age,gender,count(*) as count_id from movieNameTable group by movieName,gender,age order by count_id ");
            rowDataset.show(50);
        } catch (AnalysisException e) {
            e.printStackTrace();
        }
    }
}